import pysam
import os, subprocess
import itertools
import operator
from collections import defaultdict, OrderedDict
import re
import shutil
import gzip
from itertools import product, combinations
from Bio import SeqIO
import sys, argparse
try:
   import cPickle as pickle
except:
   import pickle

def string_hamming_distance(str1, str2):
    """
    Count the positions at which two strings differ.

    The strings are expected to have the same length; positions are compared
    pairwise and comparison stops at the end of the shorter string.
    Example: "karolin" vs "kathrin" gives 3.
    """
    return sum(left != right for left, right in zip(str1, str2))
def to_fastq(name, seq, qual):
    """
    Format one FASTQ record (header, sequence, '+' separator, quality) as a
    single string ending in a newline, ready to be written to a FASTQ file.
    """
    pieces = ('@', name, '\n', seq, '\n+\n', qual, '\n')
    return ''.join(pieces)

def sort_gzfastq(args):
    """
    Split paired gzipped FASTQ reads into a "meta" file and a "bio" file.

    ``args.r1gz`` and ``args.r2gz`` are parsed in lockstep. For each pair the
    mate whose first 40 nt contain a fragment of the gata30 primer is written
    to ``<args.path><args.sample>_meta.fastq.gz``; the other mate must contain
    a fragment of the gate_pcr primer in its first 30 nt and is written to
    ``<args.path><args.sample>_bio.fastq.gz``. Pairs that do not satisfy both
    conditions are only counted. Progress and a summary go to stderr.

    All four files are held by context managers so they are closed (and the
    gzip output finalized) even when parsing raises part-way through.
    """
    gata30 = 'GTGAGTGATGGTTGAGGATGTGTGGAGATA' #30nt
    gate_pcr = 'GGAGTTGGAGTGAGTGGATGAGTGATG' #27nt
    count = 0
    fail_one = 0
    no_meta = 0
    i = 0

    def contains_any(window, probes):
        # True when at least one probe occurs somewhere in the window.
        return any(window.find(probe) != -1 for probe in probes)

    with gzip.open(args.r1gz, "rt") as f1, \
         gzip.open(args.r2gz, "rt") as f2, \
         gzip.open(args.path+args.sample+'_meta.fastq.gz','wb') as f_meta, \
         gzip.open(args.path+args.sample+'_bio.fastq.gz','wb') as f_bio:

        sys.stderr.write('Sorting files into bio and meta')
        for r1, r2 in zip(SeqIO.parse(f1, "fastq"),SeqIO.parse(f2, "fastq")):
            i += 1
            # NOTE(review): the probe offsets differ slightly between the two
            # orientations ([2:11] vs [1:10], [11:] vs [12:]); they are kept
            # exactly as they were -- confirm whether that is intentional.
            r1_has_gata = contains_any(r1.seq[:40], (gata30[2:11], gata30[20:]))

            if r1_has_gata and contains_any(r2.seq[0:30], (gate_pcr[1:10], gate_pcr[11:])):
                # r1 is the meta read, r2 the bio read.
                f_meta.write(r1.format("fastq").encode())
                f_bio.write(r2.format("fastq").encode())
                count += 1
            elif contains_any(r2.seq[:40], (gata30[1:10], gata30[20:])):
                # r2 looks like the meta read; r1 must then look like a bio read.
                if contains_any(r1.seq[0:30], (gate_pcr[2:11], gate_pcr[12:])):
                    f_meta.write(r2.format("fastq").encode())
                    f_bio.write(r1.format("fastq").encode())
                    count += 1
                else:
                    fail_one += 1
            elif not r1_has_gata:
                no_meta += 1
            # NOTE(review): a pair where r1 carries gata30 but r2 carries
            # neither gate_pcr nor gata30 is not added to any counter (same as
            # before), so the three counters need not sum to the total.

    sys.stderr.write(f'There are total {i} reads, where {count} of them are sorted into output files, {fail_one} fail one side, and {no_meta} pair dont fit meta primer')

def trim_bio_cdna(args):
    """
    Extract cDNA inserts from the bio reads and name each after its meta read.

    Streams ``<args.path><args.sample>_bio.fastq.gz`` and
    ``<args.path><args.sample>_meta.fastq.gz`` record by record (decompressed
    by an external ``gzip`` process). For every bio read in which the anchor
    'GAGTGATG' starts at an index strictly between 15 and 35, the read is cut
    15 bases after the anchor start (8 anchor bases + 7 more), then truncated
    at the first run of six or more 'A'. Inserts longer than 16 nt are written
    to ``<args.path><args.sample>_cDNA.fastq.gz`` with the meta read's
    sequence as the record name. Counters are reported on stderr.

    Reading stops at the end of the bio stream, or at the first incomplete
    record in either stream.
    """
    bio_path = args.path+args.sample+'_bio.fastq.gz'
    meta_path = args.path+args.sample+'_meta.fastq.gz'
    cdna_path = args.path+args.sample+'_cDNA.fastq.gz'

    total_reads = 0
    cdna_reads = 0
    no_A_tail = 0
    not_cdna = 0
    cDNA_too_short = 0
    dA_tail = re.compile('A{6,}')   # dA tailing
    anchor = 'GAGTGATG'             # last 8nt of gatepcr

    # The commands are passed as argument lists (no shell), so characters in
    # the paths are never interpreted by a shell. The context managers close
    # the pipes, wait for the gzip children and finalize the output file.
    with subprocess.Popen(["gzip", "--stdout", "-d", bio_path], stdout=subprocess.PIPE) as bio_gunzip, \
         subprocess.Popen(["gzip", "--stdout", "-d", meta_path], stdout=subprocess.PIPE) as meta_gunzip, \
         gzip.open(cdna_path, 'wb') as f_cdna:
        bio_stream = bio_gunzip.stdout
        meta_stream = meta_gunzip.stdout

        while True:
            # Read 4 lines from each FastQ; only sequence/quality are used.
            next(bio_stream, None)                  # read name
            bio_seq_raw = next(bio_stream, None)    # read seq
            next(bio_stream, None)                  # + line
            bio_qual_raw = next(bio_stream, None)   # read qual

            next(meta_stream, None)                 # read name
            meta_seq_raw = next(meta_stream, None)  # meta read seq
            next(meta_stream, None)                 # + line
            next(meta_stream, None)                 # read qual

            if bio_seq_raw is None or bio_qual_raw is None or meta_seq_raw is None:
                break
            # Count only records that were actually read (not the EOF probe).
            total_reads += 1

            r1_seq = bio_seq_raw.rstrip().decode("utf-8")
            r1_qual = bio_qual_raw.rstrip().decode("utf-8")
            r2_seq = meta_seq_raw.rstrip().decode("utf-8")

            trim_position = r1_seq.find(anchor)
            if not (15 < trim_position < 35):
                not_cdna += 1
                continue

            # Skip the anchor (8) plus 7 further bases -- presumably the "7N"
            # of GATE7N mentioned in the original comment.
            insert_start = trim_position + 8 + 7
            seq = r1_seq[insert_start:]
            qual = r1_qual[insert_start:]

            tail = dA_tail.search(seq)
            if tail is not None:
                seq = seq[:tail.start()]
                qual = qual[:tail.start()]
            else:
                no_A_tail += 1

            if len(seq) > 16:
                cdna_reads += 1
                # meta read as name
                f_cdna.write(to_fastq(r2_seq, seq, qual).encode())
            else:
                cDNA_too_short += 1

    sys.stderr.write(f'{cDNA_too_short}too short, {not_cdna} not cdna read, {no_A_tail} no A tail, and {cdna_reads} cDNA reads out of {total_reads} total')

def string_hamming_distance(str1, str2):
    """
    Hamming distance between two strings assumed to be of equal length:
    the number of aligned positions holding different symbols.
    For instance "karolin" and "kathrin" differ at 3 positions.
    """
    mismatches = 0
    for left, right in zip(str1, str2):
        if left != right:
            mismatches += 1
    return mismatches

def gel_barcode_list_neighborhood(gel_barcode_list):
    """
    Return the barcode neighborhood mapping for a barcode list file.

    Thin wrapper: forwards to build_barcode_neighborhoods with
    expect_reverse_complement passed as False and returns its result unchanged.
    """
    return build_barcode_neighborhoods(gel_barcode_list, False)
def build_barcode_neighborhoods(barcode_file, expect_reverse_complement=True):
    """
    Map every sequence within 2 substitutions of a barcode to that barcode.

    ``barcode_file`` is a text file with one barcode per line. Each barcode
    maps to itself. A mutated sequence is kept only when it can be reached
    from exactly one barcode; sequences reachable with a single substitution
    are resolved before those needing two, so a 1-change match wins over a
    2-change match to a different barcode.

    ``expect_reverse_complement`` is accepted for compatibility but is not
    used (the reverse-complement step is commented out).

    Returns a dict ``{sequence: barcode}``.
    """
    # contains all mutants that map uniquely to a barcode
    clean_mapping = dict()

    # contain single or double mutants: {mutant: set of barcodes producing it}
    mapping1 = defaultdict(set)
    mapping2 = defaultdict(set)

    # Text mode already applies universal-newline handling; the legacy 'U'
    # flag is rejected by open() since Python 3.11.
    with open(barcode_file, 'r') as f:
        for line in f:
            stripped = line.rstrip()
            if not stripped:
                # A blank (e.g. trailing) line would otherwise register the
                # empty string as a valid barcode.
                continue
            barcode = capital_seq(stripped)

            # each barcode obviously maps to itself uniquely
            clean_mapping[barcode] = barcode

            # eg: barcodes CATG and CCTG would both be in the set for mutant
            # CTTG, but only barcode CATG could generate mutant CANG
            for mutant in seq_neighborhood(barcode, 1):
                mapping1[mutant].add(barcode)

            # same as above but with double mutants
            for mutant in seq_neighborhood(barcode, 2):
                mapping2[mutant].add(barcode)

    # Resolve single mutants first, then double mutants; in both passes keep
    # only sequences that could have come from one specific barcode and that
    # have not been assigned already.
    for mapping in (mapping1, mapping2):
        for mutant, origins in mapping.items():
            if len(origins) == 1 and mutant not in clean_mapping:
                clean_mapping[mutant] = next(iter(origins))

    return clean_mapping

# Base-complement lookup; in the code shown here it is only referenced by the
# disabled rev_comp helper below.
___tbl = {'A':'T', 'T':'A', 'C':'G', 'G':'C', 'N':'N'}
#def rev_comp(seq):
#    return ''.join(___tbl[s] for s in seq[::-1])
# Upper-casing lookup used by capital_seq. Upper-case bases and lower-case 'n'
# are included so barcode lists that are already capitalised are accepted too.
___tb2 = {'a':'A', 'g':'G', 'c':'C', 't':'T', 'n':'N',
          'A':'A', 'G':'G', 'C':'C', 'T':'T', 'N':'N'}
def capital_seq(seq):
    """
    Return ``seq`` with every base upper-cased.

    Accepts A/C/G/T/N in either case. Any other character raises KeyError.
    """
    return ''.join(___tb2[s] for s in seq)
def seq_neighborhood(seq, n_subs=1):
    """
    Generate every sequence obtained from ``seq`` by overwriting exactly
    ``n_subs`` distinct positions with letters from "ATGCN".

    Positions are chosen as combinations of indices and, for each choice,
    every assignment of letters is tried. A position may be overwritten with
    the letter it already holds, so ``seq`` itself (and results with fewer
    real changes) are also produced, possibly more than once.
    """
    alphabet = "ATGCN"
    for chosen_positions in combinations(range(len(seq)), n_subs):
        for letters in product(alphabet, repeat=n_subs):
            # index -> replacement base for this particular mutant
            replacement = dict(zip(chosen_positions, letters))
            yield ''.join(
                replacement.get(index, base) for index, base in enumerate(seq)
            )
def from_fastq(handle):
    """
    Yield ``(name, seq, qual)`` tuples from an iterator of FASTQ text lines.

    ``handle`` must be an iterator (e.g. an open text file). The leading
    character of the header line is dropped from the name and trailing
    whitespace is stripped from all three fields. Iteration ends cleanly at
    end of input, at a truncated final record, or at a record with an empty
    name, sequence or quality.
    """
    while True:
        # next() with a default: a bare next() raising StopIteration inside a
        # generator is turned into RuntimeError (PEP 479).
        header = next(handle, None)     # Read name
        seq_line = next(handle, None)   # Read seq
        next(handle, None)              # + line
        qual_line = next(handle, None)  # Read qual
        if qual_line is None:
            # End of input, or fewer than four lines left.
            return
        name = header.rstrip()[1:]
        seq = seq_line.rstrip()
        qual = qual_line.rstrip()
        if not name or not seq or not qual:
            return
        yield name, seq, qual

def trim_meta(args): #args.path,args.sample, args.biotype, args.bc1, args.bc2, args.bc3
    """
    Replace each record's header with its corrected barcodes and UMI.

    Streams ``<args.path><args.sample>_<args.biotype>.fastq`` through an
    external ``gzip --stdout -d`` process. In every header line the anchor
    'GAGATA' must start at an index strictly between 5 and 35; relative to
    that index the fields are bc1 = [+6, +16), bc2 = [+20, +30),
    bc3 = [+34, +44) and umi = [+44, +52). Records whose UMI contains 'N' or
    whose barcodes cannot be mapped through the neighborhoods built from
    ``args.bc1`` / ``args.bc2`` / ``args.bc3`` are counted and dropped. The
    rest are written to ``<args.path><args.sample>_<args.biotype>_bcumi.fastq.gz``
    named ``<bc1><bc2><bc3>:<umi>``. A summary is written to stderr.

    Reading stops at end of input or at the first incomplete record.
    """
    valid_bc1s = gel_barcode_list_neighborhood(args.bc1)
    valid_bc2s = gel_barcode_list_neighborhood(args.bc2)
    valid_bc3s = gel_barcode_list_neighborhood(args.bc3)

    bc1_fail = 0
    bc2_fail = 0
    bc3_fail = 0
    umi_fail = 0
    total_reads = 0
    saved_reads = 0

    # NOTE(review): the input name carries no ".gz" suffix although it is fed
    # to gzip -d; presumably gzip's own suffix lookup is relied on -- confirm.
    in_path = args.path+args.sample+'_'+args.biotype+'.fastq'
    out_path = args.path+args.sample+'_'+args.biotype+'_bcumi.fastq.gz'

    # Argument list instead of a shell string: the path is never parsed by a
    # shell. The context managers wait for gzip and finalize the output file.
    with subprocess.Popen(["gzip", "--stdout", "-d", in_path], stdout=subprocess.PIPE) as gunzip, \
         gzip.open(out_path, 'wb') as f:
        stream = gunzip.stdout
        while True:
            meta = next(stream, None)
            seq = next(stream, None)
            next(stream, None)      # + line
            qual = next(stream, None)
            if meta is None or seq is None or qual is None:
                break
            # Count only records that were actually read (not the EOF probe).
            total_reads += 1

            meta = meta.decode("utf-8")
            seq = seq.decode("utf-8")
            qual = qual.decode("utf-8")

            pos = meta.find('GAGATA') # last 6nt of gata30
            if not (5 < pos < 35):
                # Anchor missing or out of range: dropped without a counter.
                continue

            bc1 = meta[pos+6:pos+16]
            bc2 = meta[pos+20:pos+30]
            bc3 = meta[pos+34:pos+44]
            umi = meta[pos+44:pos+52]

            if 'N' in umi:
                umi_fail += 1
            elif bc1 not in valid_bc1s:
                bc1_fail += 1
            elif bc2 not in valid_bc2s:
                bc2_fail += 1
            elif bc3 not in valid_bc3s:
                bc3_fail += 1
            else:
                # An observed barcode may be a neighbor of a valid barcode
                # rather than the barcode itself; map it to the valid one.
                bc = valid_bc1s[bc1] + valid_bc2s[bc2] + valid_bc3s[bc3]
                name = bc+':'+umi
                out = to_fastq(name.rstrip(),seq.rstrip(),qual.rstrip())
                f.write(out.encode())
                saved_reads += 1

    sys.stderr.write(f'{bc1_fail} bc1 fail, {bc2_fail} bc2 fail,{bc3_fail} bc3 fail and {umi_fail} umi fail from {total_reads} total, saved {saved_reads} reads')

# put all reads with a valid sorted barcode into separate fastq file
from collections import defaultdict
def valid_bc_sep_file(args): #fastq, list_of_valid_bc, path, sample
    """
    Split a barcode/UMI-named FASTQ into one file per barcode, collapsing
    UMI duplicates.

    Reads ``args.fastq`` (plain text, headers of the form ``@<bc>:<umi>``
    where the barcode is the 30 characters after '@'). For every barcode the
    first read seen for a UMI is appended to ``<args.path><bc>.fastq``. A
    later read with the same barcode and UMI is treated as a duplicate when
    its sequence is within Hamming distance 2 of the stored one; otherwise it
    is stored under a derived UMI ``<umi>_<differing bases>`` and written once.
    At most 500000000 reads are processed, and the running read count is
    printed after each read.

    Filtering against a pickled list of valid barcodes
    (``args.validBC_list``) is currently disabled: every barcode is accepted.
    """
    total_reads = 0
    # {bc: {umi: [number of reads seen, sequence of the first read]}}
    sum_dict = {}

    with open(args.fastq,'r') as stream:
        while total_reads < 500000000:
            meta = next(stream, None)
            seq = next(stream, None)
            next(stream, None)      # + line
            qual = next(stream, None)
            if meta is None or seq is None or qual is None:
                # End of input (or a truncated final record).
                break

            bc = meta[1:31]
            umi = meta.split(':')[1].rstrip()
            total_reads += 1

            umis_for_bc = sum_dict.setdefault(bc, {})
            first_seen = umis_for_bc.get(umi)

            # Decide under which UMI (if any) this read must be written out.
            umi_to_write = None
            if first_seen is None:
                umi_to_write = umi
            elif string_hamming_distance(seq, first_seen[1]) <= 2: #umi duplicate
                first_seen[0] += 1
            else:
                new_umi = umi +'_'+diff(seq, first_seen[1])
                if new_umi in umis_for_bc:
                    umis_for_bc[new_umi][0] += 1
                else:
                    umi_to_write = new_umi

            if umi_to_write is not None:
                umis_for_bc[umi_to_write] = [1, seq]
                name = bc+':'+umi_to_write
                out = to_fastq(name.rstrip(),seq.rstrip(),qual.rstrip())
                # Open the per-barcode file only when there is something to
                # append, and always close it again.
                with open(args.path+'{}.fastq'.format(bc),'a') as f:
                    f.write(out)

            print(total_reads)

def diff(str1,str2):
    """
    Return the characters of ``str1`` that differ from the character at the
    same position in ``str2``, concatenated in order. Only positions present
    in both strings are compared.
    """
    return ''.join(left for left, right in zip(str1, str2) if left != right)

#compile RSEM genes.results
from numpy.lib.function_base import extract
import os
import pandas as pd

def rsem_df(args): #path, sample, fastq_dir
    """
    Combine the ``expected_count`` column of every ``*.genes.results`` file
    in ``args.fastq_dir`` into one gene-by-file table.

    Each file is read as tab-separated text and contributes one column,
    numbered 0, 1, 2, ... in sorted file-name order; rows are ``gene_id``
    values. The table is written to
    ``<args.path><args.sample>_cell_gene.csv``. The index of each file is
    printed as it is processed.
    """
    dirname = args.fastq_dir
    ext = '.genes.results'

    # os.listdir returns names in arbitrary order; sorting makes the column
    # numbering reproducible between runs.
    gene_files = sorted(f for f in os.listdir(dirname) if f.endswith(ext))

    # Seed with an empty frame: keeps the result identical to growing the
    # table from pd.DataFrame() and still yields an (empty) CSV when no file
    # matches. All rows are concatenated once instead of once per file.
    rows = [pd.DataFrame()]
    for i, f in enumerate(gene_files):
        # os.path.join works whether or not fastq_dir ends with a separator.
        df = pd.read_csv(os.path.join(dirname, f), sep='\t')
        rows.append(df[['gene_id', 'expected_count']].set_index('gene_id').T)
        print(i)

    final_df = pd.concat(rows, axis=0, ignore_index=True)
    final_df.transpose().to_csv(args.path+args.sample+'_cell_gene.csv')

if __name__=="__main__":
    # Command-line entry point: parse the options shared by all sub-steps and
    # dispatch on the positional ``command``.

    import sys, argparse
    parser = argparse.ArgumentParser()

    parser.add_argument('-r1gz', type=str)
    parser.add_argument('-r2gz', type=str)
    parser.add_argument('-outpath', type=str)
    parser.add_argument('-sample', type=str)
    parser.add_argument('-path', type=str)
    parser.add_argument('-hamming_threshold', type=str, help='hamming_threshold_for_gate_pcr_matching')
    parser.add_argument('-biotype', type=str,help='cDNA or AbOligo')
    parser.add_argument('-bc1', type=str,help='bc1 file path')
    parser.add_argument('-bc2', type=str,help='bc2 file path')
    parser.add_argument('-bc3', type=str,help='bc3 file path')
    parser.add_argument('-feature_file', type=str)
    parser.add_argument('-sam_file', type=str)
    parser.add_argument('-fastq', type=str)
    parser.add_argument('-validBC_list', type=str)
    parser.add_argument('-fastq_dir', type=str)
    parser.add_argument('-l', '--libraries', type=str, help='[all] Library name(s) to work on. If blank, will iterate over all libraries in project.', nargs='?', default='')
    parser.add_argument('-r', '--runs', type=str, help='[all] Run name(s) to work on. If blank, will iterate over all runs in project.', nargs='?', default='')
    parser.add_argument('command', type=str, choices=['rsem_df','valid_bc_sep_file','sort_file','filter_bio_ab','filter_bio_cdna','filter_meta','cell_gene_matrix','chr_gene_dict','merge_csv','transpose'])

    args = parser.parse_args()
    if args.command == 'sort_file': #r1gz, r2gz, path, sample
        sort_gzfastq(args)

    # elif args.command == 'filter_bio_ab':
    #     trim_bio_ab(args) #args.path,args.sample, args.hamming_threshold

    elif args.command == 'filter_bio_cdna':
        trim_bio_cdna(args) #args.path,args.sample, args.hamming_threshold

    elif args.command == 'filter_meta': #args.path,args.sample, args.biotype, args.bc1, args.bc2, args.bc3
        trim_meta(args)

    # elif args.command == 'chr_gene_dict': # args.feature_file, args.path
    #     get_chr_gene_dict(args)

    elif args.command == 'cell_gene_matrix': #args.sam_file, path, sample
        # NOTE(review): get_bc_gene_dict, get_cell_gene_dict and to_csv are
        # not defined in the part of the file reviewed here -- confirm they
        # exist before relying on this command.
        bc_gene_dict= get_bc_gene_dict(args) #args.sam_file,
        cell_gene_dict=get_cell_gene_dict(bc_gene_dict) #args.path
        to_csv(args) #args.path

    elif args.command == 'valid_bc_sep_file':
        valid_bc_sep_file(args) #-fastq  -path -sample list_of_valid_bc

    elif args.command == 'rsem_df':
        rsem_df(args)

    else:
        # The parser accepts more command names than are dispatched above;
        # say so instead of exiting silently as if the step had run.
        sys.stderr.write(f"Command '{args.command}' has no handler in this script; nothing was done.\n")